#!/usr/bin/python

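# Arguments:
# 1) Quadrature mode: echo-antiecho (EAE) or States (S)
# 2) Zero-order phase p0 (degrees) for the direct dimension
# 3) Zero-order phase p0 (degrees) for the indirect dimension
# 4) Sampling schedule file (e.g. Sampling.txt), one "start-end" interval per line
# 5) Reconstruction algorithm passed to cssolver (IST or IRLS)
# 6) Number of iterations
# 7) Virtual echo (y/n)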

# ./2D_varian_direct_chunks.py EAE 0 0 Sampling.txt IST 100 n

import os
import struct
import subprocess
import sys

import nmrglue as ng
from numpy import *

print "\nReading data...",
dic, data = ng.fileio.varian.read(fid_file='fid', procpar_file='procpar')
D = shape(data)
print "done!\n"

### Echo-antiecho -> states
if sys.argv[1] == 'EAE':
    print "\nConverting 'Echo-antiecho' to 'States'...",
    data1 = zeros(D, dtype=complex)
    i = 0
    while i < D[0]:
        data1[i, :] = (data[i, :] + data[i + 1, :])/2
        data1[i + 1, :] = 1j*(-data[i, :] + data[i + 1, :])/2
        i = i + 2
    data = data1 
    print "done!\n"
elif sys.argv[1] == 'S':
    print "\n'States' mode assumed."
else:
    print "\nNo quadrature given! Script terminated."
    sys.exit() 

### Direct dimesion phasing
print "\nPhasing direct dimension...",
p0_dir = float(sys.argv[2]) 
p0_indir = float(sys.argv[3])
data = data*exp(-1j*p0_dir*pi/180) # complex numbers correspond to direct dimension
print "done!\n"

### Indirect dimension FT
print "\nIndirect dimension: phasing and FT...",
data_r_indirect = real(data)[0::2, :] + 1j*real(data)[1::2, :] # 'artificial' complex numbers in indirect dimension
data_r_indirect = data_r_indirect*exp(-1j*p0_indir*pi/180) # phasing indirect dimension
data_r_indirect = fft.fft(data_r_indirect, axis = 0) # fft with axis=0: "over rows", along columns
data_r_indirect = vstack((real(data_r_indirect), imag(data_r_indirect)))

data_i_indirect = imag(data)[0::2, :] + 1j*imag(data)[1::2, :]
data_i_indirect = data_i_indirect*exp(-1j*p0_indir*pi/180) 
data_i_indirect = fft.fft(data_i_indirect, axis = 0)
data_i_indirect = vstack((real(data_i_indirect), imag(data_i_indirect)))

data = data_r_indirect + 1j*data_i_indirect # first half rows are 'cos', second half - 'sin'. Complex numbers correspond to direct dimension.
print "done!\n"

### Schedule
print "\nReading sampling scheme...",
try:
    sched = loadtxt(sys.argv[4], delimiter='-')
except ValueError:
    print "\nWrong schedule format! Script terminated."
    sys.exit()     
start = sched[:, 0]
end = sched[:, 1]
print "done!\n"

### Sampled point indices: expand every inclusive [start, end] interval
ind = concatenate([arange(lo, hi + 1) for lo, hi in zip(start, end)])
I = len(ind)
ind = ind.astype(int)

# Positions 2*k and 2*k + 1 of every sampled point k, matching a buffer that
# stores real and imaginary values interleaved.
ind_alternated = zeros(2*I)
ind_alternated[0::2] = 2*ind
ind_alternated[1::2] = 2*ind + 1
ind_alternated = ind_alternated.astype(int)

print "\nPreparing schedule for .mdd files...",
regions = 1
dimensions = 2
line1 = str(dimensions) + ' 1 ' + str(regions*2*I) + '\n'
line2 = str(regions) + ' ' + str(2*D[1])
s = ''
for i in range(2*I):
    s = s + '\n' + str(0) + ' ' + str(ind_alternated[i])
print "done!\n"

folder = os.getcwd() + '/MDD' 
if not os.path.exists(folder):
    os.makedirs(folder)
os.chdir(folder)

### CS solver
print "\nReconstructing gaps row-wise...",
def _solver_command(name, algorithm, n_iter, virtual_echo):
    """Return the cssolver command line for ``<name>.mdd``.

    Every option is emitted as a separate, space-separated word.  Without
    the spaces cssolver would receive a single token such as
    ``CS_niter=100CS_VE=nMDD_NOISE=0``.  The solver's stdout is
    redirected to ``./<name>.log``.
    """
    return ' '.join([
        'cssolver', name,
        'CS_alg=' + algorithm,
        'CS_niter=' + n_iter,
        'CS_VE=' + virtual_echo,
        'MDD_NOISE=0',
        '>', './' + name + '.log',
    ])


def rec(i):
    """Reconstruct row ``i`` of ``data`` with cssolver and return it.

    Writes ``<i+1>.mdd`` in the current directory.  The file holds the MDD
    header, the schedule text from the globals ``line1``, ``line2`` and
    ``s``, and the sampled points ``data[i, ind]`` as interleaved
    real/imaginary values.  It then runs cssolver on the file through
    tcsh, using algorithm, iteration count and virtual-echo flag from
    command-line arguments 5-7.  Finally it reads ``2*D[1]`` native-order
    4-byte floats back from ``<i+1>.cs``.

    Returns a complex array of ``D[1]`` points computed as
    ``-1j*cs[0::2] + cs[1::2]``.
    """
    print('row  %d' % i)
    name = str(i + 1)

    # Sampled points of this row, written as interleaved real/imag values.
    fid = data[i, ind]
    fid_alternated = empty(2*I)
    fid_alternated[0::2] = real(fid)
    fid_alternated[1::2] = imag(fid)
    f = ''.join('\n' + str(value) for value in fid_alternated)
    contents = 'mdd asc sparse f180.0 \n ./MDD/region01.mdd \n MDD sparse\n $ \n' + line1 + line2 + s + f # contains the data to be written to .mdd file

    with open(name + '.mdd', 'w') as mdd_file:
        mdd_file.write(contents)

    # Hand the command to tcsh as one argv entry instead of wrapping it in
    # single quotes inside an os.system() string (which broke on quotes).
    command = _solver_command(name, sys.argv[5], sys.argv[6], sys.argv[7])
    subprocess.call(['tcsh', '-c', command])

    ### Reading .CS file: 2*D[1] native-order 4-byte floats
    n_values = 2*D[1]
    with open(name + '.cs', 'rb') as cs_file:
        raw = cs_file.read(4*n_values)
    cs = array(struct.unpack('%df' % n_values, raw))
    return -1j*cs[0::2] + cs[1::2]  # output of CS solver

output = zeros(D, dtype=complex)
import multiprocessing as mp
pool = mp.Pool()
results = [pool.apply_async(rec, args=(i,)) for i in range(D[0])]
output = [i.get() for i in results]
print "done!\n"

### Indirect dimension iFT
print "\nIndirect dimension iFT...",
fid_r_indirect = real(output)[:D[0]/2, :] + 1j*real(output)[D[0]/2:, :] # 'artificial' complex numbers in indirect dimension
fid_r_indirect = fft.ifft(fid_r_indirect, axis = 0) # fft with axis=0: "over rows", along columns
fid_r_indirect = fid_r_indirect*exp(1j*p0_indir*pi/180) # "dephasing"

fid_i_indirect = imag(output)[:D[0]/2, :] + 1j*imag(output)[D[0]/2:, :]
fid_i_indirect = fft.ifft(fid_i_indirect, axis = 0)
fid_i_indirect = fid_i_indirect*exp(1j*p0_indir*pi/180)

fid_rec = zeros(D, dtype=complex)
fid_rec[0::2, :] = real(fid_r_indirect) + 1j*real(fid_i_indirect)
fid_rec[1::2, :] = imag(fid_r_indirect) + 1j*imag(fid_i_indirect)
print "done!\n"

if sys.argv[1] == 'EAE':
    print "\nConverting 'States' to 'Echo-antiecho'...",
    fid_rec1 = zeros(D, dtype=complex)
    i = 0
    while i < D[0]:
        fid_rec1[i, :] = fid_rec[i + 1, :] - 1j*fid_rec[i, :]
        fid_rec1[i + 1, :] = fid_rec[i + 1, :] + 1j*fid_rec[i, :]
        i = i + 2
    fid_rec = fid_rec1 
    print "done!\n"

### Wrinting to .FID file
print "\nMaking fid_rec file..."
os.chdir("..")
ng.fileio.varian.write(os.getcwd(), dic, fid_rec, fid_file='fid_rec', procpar_file='procpar_rec', overwrite=True)
print "done!\n"
